package spark.pipeline


import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
import org.apache.spark.ml.evaluation.ClusteringEvaluator
import org.apache.spark.ml.feature.{IndexToString, StandardScaler, StringIndexer, VectorAssembler}
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DoubleType

import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable`

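/**
 * Selects a good k for KMeans on the KDD Cup 99 (10%) intrusion data: string-indexes the
 * categorical columns, assembles and standardizes the 41 features, then sweeps k from 10
 * to 200, averaging a silhouette score, a centroid-distance anomaly threshold, and a
 * squared-error score over 8 random train/test splits per k.
 */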
object SelectBestKofKMeans {
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)  // Suppress Spark's INFO spam; the console looks cleaner, though it can feel oddly empty when nothing prints
    val spark = SparkSession.builder().appName("check K of KMeans").master("local").getOrCreate()
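    // .master("local") runs Spark single-threaded; "local[*]" would use every core and shorten the long k sweep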
    val sc = spark.sparkContext
    import spark.implicits._

    val file = "C:/Users/Lenovo/Desktop/Working/Python/data/kddcup.data_10_percent_corrected.csv"
    val fileData = spark.read.csv(file)
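    // No header row in the file: columns arrive as strings _c0.._c41 and are renamed via toDF below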
    // Peek at label frequencies; with read.csv the label lands in the last column, _c41:
    //   fileData.groupBy("_c41").count().orderBy(col("count").desc).show(false)

    val dataFrame = fileData
      .toDF("duration", "protocol_type", "service", "flag", "src_bytes", "dst_bytes", "land",
        "wrong_fragment", "urgent", "hot", "num_failed_logins", "logged_in", "num_compromised", "root_shell",
        "su_attempted", "num_root", "num_file_creations", "num_shells", "num_access_files", "num_outbound_cmds",
        "is_host_login", "is_guest_login", "count", "srv_count", "serror_rate", "srv_serror_rate", "rerror_rate",
        "srv_rerror_rate", "same_srv_rate", "diff_srv_rate", "srv_diff_host_rate", "dst_host_count",
        "dst_host_srv_count", "dst_host_same_srv_rate", "dst_host_diff_srv_rate", "dst_host_same_src_port_rate",
        "dst_host_srv_diff_host_rate", "dst_host_serror_rate", "dst_host_srv_serror_rate", "dst_host_rerror_rate",
        "dst_host_srv_rerror_rate", "label")

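    // Cache the named DataFrame: every StringIndexer.fit below makes a full pass over the data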
    dataFrame.cache()
    dataFrame.show(25,false)


    // Index the categorical "protocol_type" column to a numeric column
    val protocol_indexer = new StringIndexer().setInputCol("protocol_type").setOutputCol("protocol_typeIndex").fit(dataFrame)
    val indexed_0 = protocol_indexer.transform(dataFrame)

    // Index the categorical "service" column to a numeric column
    val service_indexer = new StringIndexer().setInputCol("service").setOutputCol("serviceIndex").fit(indexed_0)
    val indexed_1 = service_indexer.transform(indexed_0)

    // Index the categorical "flag" column to a numeric column
    val flag_indexer = new StringIndexer().setInputCol("flag").setOutputCol("flagIndex").fit(indexed_1)
    val indexed_2 = flag_indexer.transform(indexed_1)

    // Index the categorical "label" column to a numeric column
    val label_indexer = new StringIndexer().setInputCol("label").setOutputCol("labelIndex").fit(indexed_2)
    val indexed_df = label_indexer.transform(indexed_2)
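
    // Alternative (a sketch): instead of fitting each StringIndexer eagerly, the un-fitted
    // estimators could be composed as Pipeline stages alongside the assembler and scaler,
    // letting a single pipeline.fit drive the whole flow and skipping indexed_0..indexed_2.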

//    indexed_df.show(10,false)

    // Drop the original string-typed categorical columns
    val df_final = indexed_df.drop("protocol_type").drop("service")
      .drop("flag").drop("label")

    val assembler = new VectorAssembler()
      .setInputCols(Array("duration", "src_bytes", "dst_bytes", "land",
        "wrong_fragment", "urgent", "hot", "num_failed_logins", "logged_in", "num_compromised", "root_shell",
        "su_attempted", "num_root", "num_file_creations", "num_shells", "num_access_files", "num_outbound_cmds",
        "is_host_login", "is_guest_login", "count", "srv_count", "serror_rate", "srv_serror_rate", "rerror_rate",
        "srv_rerror_rate", "same_srv_rate", "diff_srv_rate", "srv_diff_host_rate", "dst_host_count",
        "dst_host_srv_count", "dst_host_same_srv_rate", "dst_host_diff_srv_rate", "dst_host_same_src_port_rate",
        "dst_host_srv_diff_host_rate", "dst_host_serror_rate", "dst_host_srv_serror_rate", "dst_host_rerror_rate",
        "dst_host_srv_rerror_rate", "protocol_typeIndex", "serviceIndex", "flagIndex"))
      .setOutputCol("features")

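    // read.csv produced string columns; cast everything to Double because VectorAssembler only accepts numeric inputs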
    val cols = df_final.columns.map(f => col(f).cast(DoubleType))
    val df_finalDataFrame = assembler.transform(df_final.select(cols: _*))

    //        df_finalDataFrame.show(10,false)
    val featuresDataFrame = df_finalDataFrame.select("labelIndex","features")

    val labelBack = new IndexToString().setInputCol("labelIndex")
      .setOutputCol("label").setLabels(label_indexer.labels)

    val scaler = new StandardScaler()
      .setInputCol("features")
      .setOutputCol("scaledFeatureVector")
      .setWithStd(true)
      .setWithMean(false)
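    // withMean=false rescales to unit standard deviation without centering, so sparse vectors stay sparse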

    val train = labelBack.transform(featuresDataFrame)
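    // train carries labelIndex, features, and the recovered label string; KMeans itself reads only scaledFeatureVector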
//    val trainDF = train
//      .na
//      .replace("*",Map(
//          "warezmaster."->"R2L",
//          "smurf."->"DOS",
//          "pod."->"DOS",
//          "imap."->"R2L",
//          "nmap."->"Probing",
//          "guess_passwd."->"R2L",
//          "ipsweep."->"Probing",
//          "portsweep."->"Probing",
//          "satan."->"Probing",
//          "land."->"DOS",
//          "loadmodule."->"U2R",
//          "ftp_write."->"R2L",
//          "buffer_overflow."->"U2R",
//          "rootkit."->"U2R",
//          "warezclient."->"R2L",
//          "teardrop."->"DOS",
//          "perl."->"U2R",
//          "phf."->"R2L",
//          "multihop."->"R2L",
//          "neptune."->"DOS",
//          "back."->"DOS",
//          "spy."->"R2L"
//      ))
//    trainDF.show(5)
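    // (If enabled, the block above collapses the specific attack labels into the four
    // coarse KDD Cup categories: DOS, R2L, U2R, and Probing.)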

    // Full sweep takes more than 3 hours on a single local thread
    for (k <- 10 to 200 by 10) {
      var WSSSE: Double = 0
      var avgThreshold: Double = 0
      var avgEuclidean: Double = 0
      // Average each metric over 8 random 70/30 train/test splits
      for (n <- 1 to 8) {
        val Array(trainData, testData) = train.randomSplit(Array(0.7, 0.3))
        val km = new KMeans()
          .setFeaturesCol("scaledFeatureVector")
          .setPredictionCol("prediction")
          .setMaxIter(30)
          .setK(k)

        val pipeline = new Pipeline()
          .setStages(Array(scaler, km))
        val prKmeans = pipeline.fit(trainData)

        val predictionKMeans = prKmeans.transform(testData)

        val evaluator = new ClusteringEvaluator()
          .setPredictionCol("prediction")
          .setFeaturesCol("scaledFeatureVector")

        val kMeansModel = prKmeans.stages.last.asInstanceOf[KMeansModel]
        val centroids = kMeansModel.clusterCenters  // the k cluster centers
        // centroids.foreach(println)

        // Squared distance from each point to its assigned cluster centroid, sorted descending
        val distances = predictionKMeans.select("prediction", "scaledFeatureVector").as[(Int, Vector)]
          .map { case (cluster, vector) => Vectors.sqdist(centroids(cluster), vector) }
          .orderBy($"value".desc)

        val distanceList = distances.collectAsList()
        val length = distances.count()
        var sum: Double = 0

        // Drop the 10 largest distances as outliers before accumulating the error
        val trimmed = distanceList.takeRight(length.toInt - 10)
        trimmed.foreach { x =>
          sum += x  // Vectors.sqdist is already squared; squaring again would overweight far points
        }

        WSSSE += sum / (length - 10)  // mean squared distance over the retained points
        avgThreshold += distances.take(100).last  // 100th-largest distance, usable as an anomaly threshold
        avgEuclidean += evaluator.evaluate(predictionKMeans)  // silhouette with squared Euclidean distance
      }
      println(s"k = $k: avg silhouette = ${avgEuclidean / 8}, avg centroid-distance threshold = ${avgThreshold / 8}, avg squared error = ${WSSSE / 8}")
    }
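    // Heuristic for reading the sweep: prefer the k with the highest silhouette and where
    // the squared-error curve flattens (the "elbow"); error tends to shrink as k grows,
    // so raw error alone favors ever-larger k.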
  }
}
